#!/usr/bin/env python3

from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import DonaldDuckDataset
from foolbox4attack import attackMethods, FoolboxAttack
import DonaldDuckConv
import DonaldDuckFunc
import  test
import numpy as np
import os
import tensorflow as tf

if __name__ == "__main__":
    # Script entry point: load a pre-trained victim model, generate adversarial
    # examples for every configured attack/epsilon pair, print the mean
    # perturbation distances and save the generated examples.

    # Local import: only needed for the stderr diagnostics below.
    import sys

    # Expose only GPU 0 to this process, with devices ordered by PCI bus id.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # Best effort: enable memory growth on the first GPU. A missing GPU or a
    # failure to change the setting must not abort the run, but it should not
    # be hidden either, so it is reported on stderr (stdout is reserved for
    # the distance table printed further down).
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        first_gpu = physical_devices[0]
        try:
            tf.config.experimental.set_memory_growth(first_gpu, True)
            if not tf.config.experimental.get_memory_growth(first_gpu):
                print(
                    "warning: GPU memory growth was requested but is not enabled",
                    file=sys.stderr
                )
        except (RuntimeError, ValueError) as err:
            print(
                "warning: could not enable GPU memory growth:", err,
                file=sys.stderr
            )

    # Fixed seed for TensorFlow's random number generation.
    tf.random.set_seed(123)

    # To switch dataset and victim model, comment out the active section
    # below and uncomment the alternative one.

    # MNIST, Fashion and victim model CNN-8
    # dataset=DonaldDuckDataset.Fashion(standardization=False)
    # dataset=DonaldDuckDataset.MNIST(standardization=False)
    # conv_layers_num = 5
    # init_filters = 32
    # cnn = DonaldDuckConv.DonaldDuckCNN(
    #      dataset,
    #      build_dir=False
    # )
    # cnn.setModel(
    #      conv_layers_num=conv_layers_num,
    #      filters=init_filters,
    #      kernel_size=(3,3)
    # )
    # cnn.load_model(
    #      weights_path=r'savedModels//'+ cnn.name+'.h5'
    # )

    # CIFAR and victim model VGG-16
    dataset = DonaldDuckDataset.CIFAR10(standardization=False)
    cnn = DonaldDuckConv.DonaldDuckVGG16(
        dataset,
        build_dir=False
    )
    cnn.setModel()
    cnn.load_model(
        weights_path=r'savedModels//' + cnn.name + '.h5'
    )

    # advNum is passed to FoolboxAttack and also embedded in every attack
    # name below (presumably the number of adversarial examples to craft --
    # confirm in foolbox4attack).
    advNum = 10000
    fa = FoolboxAttack(
        model=cnn,
        advNum=advNum
    )

    # lp values passed to cal_distance, in the order they are printed
    # (presumably the order of the Lp norm -- confirm in DonaldDuckFunc).
    lp_values = (0, 1, 2, np.inf)

    # attackMethods is indexed as attackMethods[group][attack] and each entry
    # provides a 'method' plus per-dataset 'epsilon' values.
    for ams in attackMethods:
        methods = attackMethods[ams]
        for am in methods:
            attack_spec = methods[am]
            for epsilon in attack_spec['epsilon'][dataset.name]:
                adv_examples, imgs, _labels = fa.create_adversarial_pattern(
                    attackMethod=attack_spec['method'],
                    # str(epsilon) is kept explicit so the generated name is
                    # exactly what string concatenation would produce.
                    attack_name='_'.join((ams, am, str(epsilon), str(advNum))),
                    epsilons=epsilon
                )
                # print(fa.test_adv())
                # One output line per attack/epsilon: the mean distance for
                # each lp value, rounded to 4 decimals, space separated.
                mean_distances = [
                    round(
                        np.mean(
                            DonaldDuckFunc.cal_distance(adv_examples, imgs, lp=lp)
                        ),
                        4
                    )
                    for lp in lp_values
                ]
                print(*mean_distances)
                fa.saveExamples(fa.model.name)